/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    KMeansReduceTask
 *    Copyright (C) 2014 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.distributed;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instances;
import weka.core.NormalizableDistance;
import weka.core.Utils;
import weka.core.stats.ArffSummaryNumericMetric;
import weka.core.stats.NominalStats;

/**
 * Reduce task for k-means clustering. Processes partial cluster summary
 * metadata for a particular run in order to produce a set of Instances that
 * contains new cluster centroids.
 * 
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: $
 */
public class KMeansReduceTask implements Serializable {

  /**
   * For serialization
   */
  private static final long serialVersionUID = 6222983145960081251L;

  /**
   * Total within cluster error for the run: the sum of the errors parsed from
   * the relation names of the partial summaries of all globally non-empty
   * clusters
   */
  protected double m_totalWithinClustersError;

  /**
   * Will hold the updated centroids for this run after aggregating the partial
   * clusterings
   */
  protected Instances m_newCentroidsForRun;

  /**
   * This will hold the aggregated summary instances (with summary stats
   * attributes), one for each cluster that is non-empty globally
   */
  protected List<Instances> m_aggregatedCentroidSummaries;

  /**
   * If the iteration number is 0, then this will hold some priming data for
   * initializing the ranges of numeric attributes in the case where filters may
   * have transformed/altered the original space. It is not (re)set on other
   * iterations.
   */
  protected Instances m_globalDistanceFunctionPrimingData;

  /** The run number */
  protected int m_runNumber;

  /** The iteration number */
  protected int m_iterationNumber;

  /**
   * Reduce the cluster centroid summary metadata instances for a particular run
   * in order to produce a new set of Instances that contains the new cluster
   * centroids for the run. Clusters that had no training instances assigned in
   * any map task are dropped, so the result may contain fewer centroids than
   * there were clusters. The within cluster error of each remaining cluster is
   * written into the relation name of its aggregated summary instances (see
   * {@link #getAggregatedCentroidSummaries()}) and the sum over all clusters is
   * available from {@link #getTotalWithinClustersError()}.
   * <p>
   * If the iteration number is 0 then also generates a two instance data set
   * that can be used for initializing a distance function. The two instances
   * contain global minimum and maximum values for numeric attributes
   * respectively, which is used in the distance function for normalization.
   * This particular dataset is useful when filters (beyond missing values
   * replacement) have been specified for k-means and it is not possible to use
   * the summary stats in the global ARFF header file for initializing the
   * distance function.
   * <p>
   * The supplied header is not modified; the new centroids are added to a copy
   * of it.
   * 
   * @param runNumber the current run number
   * @param iterationNumber the current iteration number of k-means
   * @param headerNoSummary the global ARFF header (as computed by the
   *          ArffHeader job on the entire dataset, and having passed through
   *          any preprocessing filters). We need this so that the correct index
   *          for nominal attribute values can be set in the new centroids (map
   *          tasks accumulating summary stats when clustering partitions of the
   *          data may see nominal values in different orders, or not see some
   *          values at all, compared to the global header)
   * @param clusterSummaries a list of cluster summary information. Each inner
   *          list of Instances will have been generated by a map task on a
   *          subset of the data. Each instances object in the list contains the
   *          summary stats for one cluster centroid. Inner lists are in order
   *          of centroid number. A particular Instances entry in a list may be
   *          null - this indicates that the cluster was empty within that
   *          particular map task (i.e. no training instances were assigned to
   *          it)
   * @return an instance of KMeansReduceTask with new centroids and supporting
   *         data computed.
   * @throws DistributedWekaException if no summaries are supplied, the inner
   *           lists differ in length, a summary lacks stats for an attribute,
   *           a nominal mode is not present in the global header, the header
   *           has attributes that are neither numeric nor nominal, or a
   *           partial's within cluster error cannot be parsed
   */
  public KMeansReduceTask reduceClusters(int runNumber, int iterationNumber,
    Instances headerNoSummary, List<List<Instances>> clusterSummaries)
    throws DistributedWekaException {

    if (clusterSummaries == null || clusterSummaries.isEmpty()) {
      throw new DistributedWekaException(
        "No cluster centroid summary stats were supplied to reduce");
    }

    m_runNumber = runNumber;
    m_iterationNumber = iterationNumber;

    List<List<Instances>> partialsPerCentroid =
      groupPartialsByCentroid(clusterSummaries);

    // aggregate the partials for each centroid and accumulate the errors
    CSVToARFFHeaderReduceTask reduceTask = new CSVToARFFHeaderReduceTask();
    List<Instances> aggregatedCentroidSummaries = new ArrayList<Instances>();
    m_totalWithinClustersError = 0;
    for (int i = 0; i < partialsPerCentroid.size(); i++) {
      List<Instances> partials = partialsPerCentroid.get(i);
      if (partials.isEmpty()) {
        // this is now a global empty cluster (i.e. no mappers assigned any
        // training instances to this centroid), so we just drop it
        continue;
      }

      double clusterError = getErrorsForCluster(partials);
      m_totalWithinClustersError += clusterError;

      Instances aggregated = reduceTask.aggregate(partials);
      aggregated.setRelationName("Stats for centroid " + i + " : "
        + clusterError);
      aggregatedCentroidSummaries.add(aggregated);
    }
    m_aggregatedCentroidSummaries = aggregatedCentroidSummaries;

    boolean computePrimingData = iterationNumber == 0;
    double[] globalMins = null;
    double[] globalMaxes = null;
    if (computePrimingData) {
      globalMins = new double[headerNoSummary.numAttributes()];
      globalMaxes = new double[headerNoSummary.numAttributes()];
      initializeBounds(headerNoSummary, globalMins, globalMaxes);
    }

    // build the new centroids in a copy of the header so that the caller's
    // header is left untouched (and centroids cannot accumulate across calls)
    Instances newCentroids =
      new Instances(headerNoSummary, aggregatedCentroidSummaries.size());
    for (Instances centroidSummary : aggregatedCentroidSummaries) {
      double[] centerVals =
        computeCentroidValues(headerNoSummary, centroidSummary, globalMins,
          globalMaxes);
      newCentroids.add(new DenseInstance(1.0, centerVals));
    }
    m_newCentroidsForRun = newCentroids;

    // If iteration 0 then compute global priming data for distance functions
    // in the (potentially) filtered space
    if (computePrimingData) {
      m_globalDistanceFunctionPrimingData = new Instances(headerNoSummary, 2);
      m_globalDistanceFunctionPrimingData
        .add(new DenseInstance(1.0, globalMins));
      m_globalDistanceFunctionPrimingData.add(new DenseInstance(1.0,
        globalMaxes));
    }

    return this;
  }

  /**
   * Transposes the per-map-task lists of partial summaries into one list of
   * partials per centroid, skipping null entries (clusters that were empty
   * within a particular map task).
   * 
   * @param clusterSummaries the non-empty list of per-map-task summary lists
   * @return a list (indexed by centroid number) of lists of partial summaries
   * @throws DistributedWekaException if an inner list is null or its size
   *           differs from that of the first inner list
   */
  private static List<List<Instances>> groupPartialsByCentroid(
    List<List<Instances>> clusterSummaries) throws DistributedWekaException {

    List<Instances> first = clusterSummaries.get(0);
    if (first == null) {
      throw new DistributedWekaException(
        "Encountered a null list of centroid summary stats");
    }
    int numClusters = first.size();

    List<List<Instances>> partialsPerCentroid =
      new ArrayList<List<Instances>>(numClusters);
    for (int i = 0; i < numClusters; i++) {
      partialsPerCentroid.add(new ArrayList<Instances>());
    }

    for (List<Instances> currentPartial : clusterSummaries) {
      if (currentPartial == null) {
        throw new DistributedWekaException(
          "Encountered a null list of centroid summary stats");
      }
      if (currentPartial.size() != numClusters) {
        throw new DistributedWekaException(
          "Each list of centroid summary stats should be "
            + "equal to the number of clusters. Expected " + numClusters
            + " but this list" + " contains " + currentPartial.size());
      }

      for (int j = 0; j < numClusters; j++) {
        Instances centroidPartial = currentPartial.get(j);
        if (centroidPartial != null) {
          partialsPerCentroid.get(j).add(centroidPartial);
        }
      }
    }

    return partialsPerCentroid;
  }

  /**
   * Computes the attribute values of one new centroid from its aggregated
   * summary stats. A numeric attribute takes the cluster mean (missing if more
   * values were missing than present, or if the mean is missing); a nominal
   * attribute takes the cluster mode re-indexed against the global header
   * (missing if the missing count exceeds the mode's count). When globalMins
   * and globalMaxes are non-null, they are widened with this cluster's finite
   * min/max for each numeric attribute.
   * 
   * @param headerNoSummary the global header (without summary attributes)
   * @param centroidSummary the aggregated summary instances for the centroid
   * @param globalMins running global minimums, or null to skip tracking
   * @param globalMaxes running global maximums, or null to skip tracking
   * @return the centroid's attribute values
   * @throws DistributedWekaException if an attribute is neither numeric nor
   *           nominal, its summary attribute is absent, or a nominal mode
   *           label is not a value of the global header attribute
   */
  private static double[] computeCentroidValues(Instances headerNoSummary,
    Instances centroidSummary, double[] globalMins, double[] globalMaxes)
    throws DistributedWekaException {

    int numAtts = headerNoSummary.numAttributes();
    double[] centerVals = new double[numAtts];
    for (int j = 0; j < numAtts; j++) {
      Attribute origAtt = headerNoSummary.attribute(j);
      if (!origAtt.isNumeric() && !origAtt.isNominal()) {
        // this could happen if the user has applied a streamable filter that
        // creates string attributes or something
        throw new DistributedWekaException(
          "k-means can only handle numeric and nominal attributes!");
      }

      Attribute summaryAtt =
        centroidSummary
          .attribute(CSVToARFFHeaderMapTask.ARFF_SUMMARY_ATTRIBUTE_PREFIX
            + origAtt.name());
      if (summaryAtt == null) {
        throw new DistributedWekaException(
          "Unable to find summary stats for attribute '" + origAtt.name()
            + "' in the aggregated summary for a cluster centroid");
      }

      if (origAtt.isNumeric()) {
        double nonMissingCountForAtt =
          ArffSummaryNumericMetric.COUNT.valueFromAttribute(summaryAtt);
        double missingCountForAtt =
          ArffSummaryNumericMetric.MISSING.valueFromAttribute(summaryAtt);
        double clusterMeanForAtt =
          ArffSummaryNumericMetric.MEAN.valueFromAttribute(summaryAtt);

        if (missingCountForAtt > nonMissingCountForAtt
          || Utils.isMissingValue(clusterMeanForAtt)) {
          centerVals[j] = Utils.missingValue();
        } else {
          centerVals[j] = clusterMeanForAtt;
        }

        if (globalMins != null && globalMaxes != null) {
          double min =
            ArffSummaryNumericMetric.MIN.valueFromAttribute(summaryAtt);
          double max =
            ArffSummaryNumericMetric.MAX.valueFromAttribute(summaryAtt);
          if (!Utils.isMissingValue(min) && !Double.isInfinite(min)
            && min < globalMins[j]) {
            globalMins[j] = min;
          }
          if (!Utils.isMissingValue(max) && !Double.isInfinite(max)
            && max > globalMaxes[j]) {
            globalMaxes[j] = max;
          }
        }
      } else {
        NominalStats stats = NominalStats.attributeToStats(summaryAtt);
        String clusterModeLabelForAtt = stats.getModeLabel();
        double modeCountForAtt = stats.getCount(clusterModeLabelForAtt);
        double missingCountForAtt = stats.getNumMissing();

        if (missingCountForAtt > modeCountForAtt) {
          centerVals[j] = Utils.missingValue();
        } else {
          // map tasks may have seen the nominal values in a different order,
          // so the mode's index must be looked up in the global header
          int mappedIndex = origAtt.indexOfValue(clusterModeLabelForAtt);
          if (mappedIndex < 0) {
            throw new DistributedWekaException("Unable to find nominal value '"
              + clusterModeLabelForAtt + "' in global header attribute '"
              + origAtt.name() + "'");
          }
          centerVals[j] = mappedIndex;
        }
      }
    }

    return centerVals;
  }

  /**
   * Initializes per-attribute min/max accumulators. Numeric attributes start at
   * Double.MAX_VALUE (min) and -Double.MAX_VALUE (max) so that any finite
   * observed value replaces them; all other attributes are set to missing.
   * 
   * @param header the header describing the attributes
   * @param mins the array of minimums to initialize (length numAttributes)
   * @param maxes the array of maximums to initialize (length numAttributes)
   */
  private static void initializeBounds(Instances header, double[] mins,
    double[] maxes) {
    for (int i = 0; i < header.numAttributes(); i++) {
      if (header.attribute(i).isNumeric()) {
        mins[i] = Double.MAX_VALUE;
        // NB: Double.MIN_VALUE is the smallest *positive* double, so it must
        // not be used as the starting maximum (it breaks all-negative data)
        maxes[i] = -Double.MAX_VALUE;
      } else {
        mins[i] = Utils.missingValue();
        maxes[i] = Utils.missingValue();
      }
    }
  }

  /**
   * Return the centroids for the run
   * 
   * @return the centroids as a set of instances
   */
  public Instances getCentroidsForRun() {
    return m_newCentroidsForRun;
  }

  /**
   * Get the aggregated summary data for each individual centroid. This is
   * represented as a Instances header with summary meta attributes
   * 
   * @return a list of summary meta data
   */
  public List<Instances> getAggregatedCentroidSummaries() {
    return m_aggregatedCentroidSummaries;
  }

  /**
   * Get the global distance function priming data. This contains global mins
   * and maxes for attributes in the transformed (by any filters) space
   * 
   * @return the global distance function priming data
   */
  public Instances getGlobalDistanceFunctionPrimingData() {
    return m_globalDistanceFunctionPrimingData;
  }

  /**
   * Get the run number
   * 
   * @return the run number
   */
  public int getRunNumber() {
    return m_runNumber;
  }

  /**
   * Get the current iteration number
   * 
   * @return the current iteration number
   */
  public int getIterationNumber() {
    return m_iterationNumber;
  }

  /**
   * Get the total within cluster error for this run
   * 
   * @return the total within cluster error for this run
   */
  public double getTotalWithinClustersError() {
    return m_totalWithinClustersError;
  }

  /**
   * Computes the errors for a particular cluster from a list of partial cluster
   * summary data. Each partial's relation name is expected to consist of
   * exactly two ':'-separated parts, the second of which is the error.
   * 
   * @param clusterPartials a list of Instances containing summary meta
   *          attributes
   * @return the total error for the cluster
   * @throws DistributedWekaException if a relation name does not have the
   *           expected form or its error part cannot be parsed as a double
   */
  protected static double getErrorsForCluster(List<Instances> clusterPartials)
    throws DistributedWekaException {
    double error = 0;

    for (Instances partial : clusterPartials) {
      String[] parts = partial.relationName().split(":");
      if (parts.length != 2) {
        throw new DistributedWekaException(
          "Can't find within cluster error in the "
            + "relation name of a cluster centroid partial stats instances:\n "
            + partial.toString());
      }

      try {
        error += Double.parseDouble(parts[1].trim());
      } catch (NumberFormatException e) {
        DistributedWekaException wrapped =
          new DistributedWekaException("Unable to parse within cluster error"
            + " from a cluster centroid partial stats instances: \n"
            + partial.toString());
        // keep the parse failure as the cause for easier diagnosis
        wrapped.initCause(e);
        throw wrapped;
      }
    }

    return error;
  }

  /**
   * Utility function to examine the attribute ranges in a bunch of distance
   * functions and return a two instance dataset with the global mins/maxes of
   * numeric attributes set. This can be used to "prime" a distance function.
   * Non-numeric attributes are set to missing in both instances.
   * 
   * @param distanceFuncs a list of distance functions (where each potentially
   *          has only seen part of the overall dataset)
   * @param headerNoSummary the header of the data that the distance functions
   *          have been seeing
   * @return a priming data set with global min and max values for numeric
   *         attributes
   * @throws DistributedWekaException if a problem occurs while reading the
   *           ranges of a distance function
   */
  public static Instances computeDistancePrimingDataFromDistanceFunctions(
    List<NormalizableDistance> distanceFuncs, Instances headerNoSummary)
    throws DistributedWekaException {

    int numAtts = headerNoSummary.numAttributes();
    double[] mins = new double[numAtts];
    double[] maxes = new double[numAtts];
    initializeBounds(headerNoSummary, mins, maxes);

    try {
      for (NormalizableDistance d : distanceFuncs) {
        double[][] ranges = d.getRanges();
        for (int i = 0; i < numAtts; i++) {
          // comparisons against the missing (NaN) values of non-numeric
          // attributes are always false, so those entries stay missing
          if (ranges[i][NormalizableDistance.R_MIN] < mins[i]) {
            mins[i] = ranges[i][NormalizableDistance.R_MIN];
          }
          if (ranges[i][NormalizableDistance.R_MAX] > maxes[i]) {
            maxes[i] = ranges[i][NormalizableDistance.R_MAX];
          }
        }
      }
    } catch (Exception ex) {
      throw new DistributedWekaException(ex);
    }

    Instances prime = new Instances(headerNoSummary, 2);
    prime.add(new DenseInstance(1.0, mins));
    prime.add(new DenseInstance(1.0, maxes));

    return prime;
  }
}
